{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras.layers import *\n",
    "from keras.models import *\n",
    "from keras.optimizers import *\n",
    "from keras.regularizers import l2\n",
    "from keras.utils.vis_utils import model_to_dot, plot_model\n",
    "import keras.backend as K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "characters = u'0123456789()+-*/=君不见黄河之水天上来奔流到海复回烟锁池塘柳深圳铁板烧; '\n",
    "n, width, height, n_class, channels = 100000, 900, 81, len(characters), 3\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ctc_lambda_func(args):\n",
    "    \"\"\"Return the batch CTC cost for a bundle of four tensors.\n",
    "\n",
    "    `args` is unpacked, in this order, into the predictions, the labels,\n",
    "    the input lengths and the label lengths. The first two steps along\n",
    "    axis 1 of the predictions are dropped before `K.ctc_batch_cost` is\n",
    "    evaluated.\n",
    "\n",
    "    NOTE(review): the name suggests use inside a Keras Lambda layer, but no\n",
    "    visible cell wires it into a model -- confirm against the training code.\n",
    "    \"\"\"\n",
    "    predictions, labels, input_length, label_length = args\n",
    "    trimmed = predictions[:, 2:, :]  # discard the first two time steps\n",
    "    return K.ctc_batch_cost(labels, trimmed, input_length, label_length)\n",
    "\n",
    "rnn_size = 128  # units in the Dense projection and in each GRU\n",
    "l2_rate = 1e-5  # L2 penalty factor shared by every regularized layer\n",
    "rnn_imp = 0     # forwarded as `implementation=` to every GRU\n",
    "\n",
    "\n",
    "def reverse_time(name):\n",
    "    \"\"\"Return a weight-free layer that flips its input along axis 1 (time).\"\"\"\n",
    "    return Lambda(lambda seq: K.reverse(seq, axes=1), name=name)\n",
    "\n",
    "\n",
    "# CNN feature extractor: three stages of n x (Conv -> BatchNorm -> ReLU),\n",
    "# with 32, 64 and 128 filters, each stage closed by 2x2 max-pooling.\n",
    "# The channel count comes from the config cell instead of a literal 3.\n",
    "input_tensor = Input((width, height, channels))\n",
    "x = input_tensor\n",
    "for i, n_cnn in enumerate([3, 4, 6]):\n",
    "    n_filters = 32 * 2 ** i\n",
    "    for j in range(n_cnn):\n",
    "        x = Conv2D(n_filters, (3, 3), padding='same', kernel_initializer='he_uniform',\n",
    "                   kernel_regularizer=l2(l2_rate))(x)\n",
    "        x = BatchNormalization(gamma_regularizer=l2(l2_rate), beta_regularizer=l2(l2_rate))(x)\n",
    "        x = Activation('relu')(x)\n",
    "    x = MaxPooling2D((2, 2))(x)\n",
    "\n",
    "# x = AveragePooling2D((1, 2))(x)  # disabled\n",
    "cnn_model = Model(input_tensor, x, name='cnn')\n",
    "\n",
    "# Recognition model: a fresh input routed through the CNN as a nested model.\n",
    "input_tensor = Input((width, height, channels))\n",
    "x = cnn_model(input_tensor)\n",
    "\n",
    "# Axis 1 of the CNN output is used as the sequence axis; axes 2 and 3 are\n",
    "# flattened into one feature vector per step.\n",
    "conv_shape = x.get_shape().as_list()\n",
    "rnn_length = conv_shape[1]\n",
    "rnn_dimen = conv_shape[3] * conv_shape[2]\n",
    "\n",
    "# %-formatting prints the same text as the Python 2 print statement and also parses on Python 3.\n",
    "print('%s %s %s' % (conv_shape, rnn_length, rnn_dimen))\n",
    "\n",
    "x = Reshape(target_shape=(rnn_length, rnn_dimen))(x)\n",
    "# Same count as the two leading steps that ctc_lambda_func slices off (y_pred[:, 2:, :]).\n",
    "rnn_length -= 2\n",
    "\n",
    "x = Dense(rnn_size, kernel_initializer='he_uniform', kernel_regularizer=l2(l2_rate), bias_regularizer=l2(l2_rate))(x)\n",
    "x = BatchNormalization(gamma_regularizer=l2(l2_rate), beta_regularizer=l2(l2_rate))(x)\n",
    "x = Activation('relu')(x)\n",
    "# x = Dropout(0.2)(x)  # disabled\n",
    "\n",
    "# A GRU built with go_backwards=True returns its sequence in reversed time\n",
    "# order. Each backward output is flipped back before merging so that step t\n",
    "# of the forward pass is combined with step t of the backward pass; the GRU\n",
    "# layer names are unchanged, the reversing layers hold no weights.\n",
    "gru_1 = GRU(rnn_size, implementation=rnn_imp, return_sequences=True, name='gru1')(x)\n",
    "gru_1b = GRU(rnn_size, implementation=rnn_imp, return_sequences=True, go_backwards=True, name='gru1_b')(x)\n",
    "gru_1b = reverse_time('gru1_b_reversed')(gru_1b)\n",
    "gru1_merged = add([gru_1, gru_1b])\n",
    "\n",
    "gru_2 = GRU(rnn_size, implementation=rnn_imp, return_sequences=True, name='gru2')(gru1_merged)\n",
    "gru_2b = GRU(rnn_size, implementation=rnn_imp, return_sequences=True, go_backwards=True, name='gru2_b')(gru1_merged)\n",
    "gru_2b = reverse_time('gru2_b_reversed')(gru_2b)\n",
    "x = concatenate([gru_2, gru_2b])\n",
    "\n",
    "# x = Dropout(0.2)(x)  # disabled\n",
    "# Per-step softmax over the n_class symbols.\n",
    "x = Dense(n_class, activation='softmax', kernel_regularizer=l2(l2_rate), bias_regularizer=l2(l2_rate))(x)\n",
    "rnn_out = x\n",
    "base_model = Model(input_tensor, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Draw the full model graph with tensor shapes; 'model.png' is plot_model's default target, spelled out here.\n",
    "plot_model(base_model, to_file='model.png', show_shapes=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Draw the CNN sub-model on its own, with tensor shapes, into level2_cnn_model.png.\n",
    "plot_model(cnn_model, to_file='level2_cnn_model.png', show_shapes=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
